/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_ENGINE_OPTIMIZER_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_ENGINE_OPTIMIZER_H_

#include "ir/func_graph.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/optimizer.h"
#include "mindspore/core/symbolic_shape/symbol.h"

namespace mindspore {
namespace opt {
namespace irpass {
// Graph-level pass functor that attaches a SymbolEngine to a func graph.
// When `only_dynshape_graph` is true (the default), the engine is built only
// for graphs that actually contain dynamic-shape nodes.
class SymbolEngineBuilder {
 public:
  explicit SymbolEngineBuilder(bool only_dynshape_graph = true) : only_dynshape_graph_{only_dynshape_graph} {}
  ~SymbolEngineBuilder() = default;
  // Runs the pass on `func_graph`; returns whether the graph was changed.
  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &opt);

 protected:
  // Reports whether any node managed by `opt` has a dynamic shape.
  bool HasDynamicShapeNode(const OptimizerPtr &opt) const;
  // When true, skip building the SymbolEngine for fully static-shape graphs.
  bool only_dynshape_graph_{true};
};

/**
 * Eliminate the ShapeCalc-Reduce-Reshape pattern generated by BroadcastGradientArgs.
 *
 * When symbolic-shape analysis proves the input shape equals the broadcast output
 * shape, the gradient's reduce/reshape round-trip is a no-op and can be bypassed:
 *
 * %5 = Add(a, b)  // when shape of "a" is equal to shape of "%5"
 * ...
 * %10 = ShapeCalc(a, b)   // backward op of "%5-Add".
 * %11 = TupleGetItem(%10, 0)
 * %12 = ReduceSum(dout, %11)
 * %13 = Shape(a)
 * %14 = Reshape(%12, %13)
 * %15 = op(%14)
 * --->
 * %5 = Add(a, b)
 * ...
 * %10 = op(dout)
 *
 * There may be another `TupleGetItem(%10, 1)` branch. When both branches are eliminated together, the "ShapeCalc"
 * itself is eliminated as well.
 */
class ElimShapeCalcOnBroadcastArgsGrad : public AnfVisitor {
 public:
  // Returns the replacement node when the pattern matches, or nullptr otherwise.
  AnfNodePtr operator()(const OptimizerPtr &opt, const AnfNodePtr &node) override;

 protected:
  // Checks whether the branch of `shape_calc` selected by `input_index` is eliminable.
  bool Check(const OptimizerPtr &opt, const AnfNodePtr &shape_calc, size_t input_index);
  // Compares input/output symbolic shapes element-wise; `shift` aligns ranks
  // when the shapes have different lengths.
  bool CheckSymbolEqual(const ListSymbolPtr &input_shape, const ListSymbolPtr &output_shape, size_t shift);
};

// Some ops like ReduceSum or Reshape, if the input shape and output shape are the same (in symbolic shape), it means
// that this op is not effective in running, so we can eliminate it.
class ElimNotEffectiveNode : public AnfVisitor {
 public:
  // Returns the node's input as the replacement when the op is a proven no-op, or nullptr otherwise.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};

// If the symbolic value of the "shape" input is fully static, or has exactly one "-1",
// replace the "shape" input with a constant tensor so Reshape no longer depends on a runtime-computed shape.
class OptReshape : public AnfVisitor {
 public:
  // Returns the rewritten Reshape node when the shape can be made constant, or nullptr otherwise.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};

/**
 * Fold the input cnode when its symbolic value is a constant.
 *
 * example:
 * %0 = ShapeCalc(p, ()) // the ShapeCalc has two outputs
 * %1 = TupleGetItem(%0, 1) // the symbolic value of item 1 is const.
 * %2 = Tile(p, %1)
 * -->
 * %2 = Tile(p, const_value)
 */
class FoldConstSymbol : public AnfVisitor {
 public:
  // Returns the node rebuilt with constant inputs when folding applies, or nullptr otherwise.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};

// Graph-level pass functor: common-subexpression elimination for shape-related ops.
//
// Rule of Zero: the class is stateless, so no special member functions are
// declared. (The previously user-declared `~ShapeOpCse() = default;` also
// suppressed the implicit move operations; the implicit defaults are identical
// in behavior and keep the class trivially movable.)
class ShapeOpCse {
 public:
  // Runs CSE over `func_graph`; returns whether the graph was changed.
  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_ENGINE_OPTIMIZER_H_
